"""
Utility functions for plotting sympy functions.

See examples/mplot2d.py and examples/mplot3d.py for ready-to-use 2d and 3d
graphing functions built on matplotlib.
"""

from numpy import (repeat, arange, empty, ndarray, array, linspace,
                   meshgrid, nan)
from sympy import Symbol, Basic, Real, Rational, I

def sample2d(f, x_args):
    """
    Samples a 2d function f over specified intervals and returns two
    arrays (X, Y) suitable for plotting with matlab (matplotlib)
    syntax. See examples/mplot2d.py.

    f is a function of one variable, such as x**2.
    x_args is an interval given in the form (var, min, max, n), where
    n is the number of subintervals: n+1 evenly spaced points from min
    to max (both endpoints included) are sampled.

    Wherever f cannot be evaluated or converted to a float, the
    corresponding entry of Y is NaN.

    Raises ValueError if f cannot be sympified, if x_args does not
    unpack into four values, or if n is less than 1.
    """
    try:
        f = Basic.sympify(f)
    except Exception:
        raise ValueError("f could not be interpretted as a SymPy function")
    try:
        x, x_min, x_max, x_n = x_args
    except (TypeError, ValueError):
        # TypeError: not iterable; ValueError: wrong number of items.
        raise ValueError("x_args must be a tuple of the form (var, min, max, n)")

    count = int(x_n) + 1
    if count < 2:
        raise ValueError("n in x_args must be at least 1")

    # linspace always yields exactly n+1 points and hits both endpoints;
    # arange with a float step may overshoot x_max or drop it because the
    # number of points it produces depends on floating-point rounding.
    X = linspace(float(x_min), float(x_max), count)

    Y = empty(len(X))
    for i, x_val in enumerate(X):
        try:
            Y[i] = float(f.subs(x, x_val))
        except Exception:
            # Best-effort sampling: record a gap instead of aborting.
            # Catching Exception (not a bare except) lets Ctrl-C through.
            Y[i] = nan
    return X, Y

def sample3d(f, x_args, y_args):
    """
    Samples a 3d function f over specified intervals and returns three
    2d arrays (X, Y, Z) suitable for plotting with matlab (matplotlib)
    syntax. See examples/mplot3d.py.

    f is a function of two variables, such as x**2 + y**2.
    x_args and y_args are intervals given in the form (var, min, max, n),
    where n is the number of subintervals: n+1 evenly spaced points from
    min to max (both endpoints included) are sampled along each axis.

    X and Y are built by numpy.meshgrid, so rows follow the y samples and
    columns follow the x samples; Z[j, k] holds f at (X[j, k], Y[j, k]).
    Wherever f cannot be evaluated or converted to a float, Z holds 0.

    Raises ValueError if f cannot be sympified, if either interval does
    not unpack into four values, or if either n is less than 1.
    """
    try:
        f = Basic.sympify(f)
    except Exception:
        raise ValueError("f could not be interpretted as a SymPy function")
    try:
        x, x_min, x_max, x_n = x_args
        y, y_min, y_max, y_n = y_args
    except (TypeError, ValueError):
        # TypeError: not iterable; ValueError: wrong number of items.
        raise ValueError("x_args and y_args must be tuples of the form (var, min, max, intervals)")

    x_count = int(x_n) + 1
    y_count = int(y_n) + 1
    if x_count < 2 or y_count < 2:
        raise ValueError("n in x_args and y_args must each be at least 1")

    # linspace always yields exactly n+1 points and hits both endpoints;
    # arange with a float step may overshoot max or drop it because the
    # number of points it produces depends on floating-point rounding.
    x_a = linspace(float(x_min), float(x_max), x_count)
    y_a = linspace(float(y_min), float(y_max), y_count)

    # Default 'xy' indexing: shape is (len(y_a), len(x_a)).
    X, Y = meshgrid(x_a, y_a)

    rows, cols = X.shape
    Z = empty((rows, cols))
    for j in range(rows):
        for k in range(cols):
            try:
                Z[j, k] = float(f.subs(x, X[j, k]).subs(y, Y[j, k]))
            except Exception:
                # Best-effort sampling; catching Exception (not a bare
                # except) lets Ctrl-C through.
                # NOTE(review): failures map to 0 here but to NaN in
                # sample2d; a 0 is indistinguishable from a genuine zero
                # on the plotted surface. Kept for backward compatibility.
                Z[j, k] = 0
    return X, Y, Z

def sample(f, *var_args):
    """
    Samples a 2d or 3d function over specified intervals and returns
    a dataset suitable for plotting with matlab (matplotlib) syntax.
    Thin dispatcher over sample2d and sample3d.

    f is a function of one or two variables, such as x**2.
    var_args holds one interval per variable, each given in the form
    (var, min, max, n). One interval selects sample2d, two select
    sample3d; any other count raises ValueError.
    """
    interval_count = len(var_args)
    if interval_count not in (1, 2):
        raise ValueError("Only 2d and 3d sampling are supported at this time.")
    if interval_count == 1:
        return sample2d(f, *var_args)
    return sample3d(f, *var_args)
